// Copyright (C) 2018 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//


#include <net_wrappers.h>

namespace MobileSearch {
    /// Loads the LPR (License Plate Recognition) network, validates and
    /// configures its inputs/outputs, and compiles it for the target device.
    ///
    /// @param ie              Inference Engine core used to load the network.
    /// @param deviceName      Target device name (e.g. "CPU", "GPU").
    /// @param xmlPath         UNUSED — see note below.
    /// @param autoResize      UNUSED — see note below.
    /// @param pluginConfig    Plugin configuration passed to LoadNetwork.
    /// @param model_lpr_path  Path to the LPR model .xml (weights .bin expected alongside).
    /// @param auto_resize     If true, let the plugin resize (NHWC + bilinear preproc).
    /// @throws std::logic_error if the network does not have exactly 2 inputs and 1 output.
    Lpr::Lpr(InferenceEngine::Core &ie, const std::string &deviceName, const std::string &xmlPath,
             const bool autoResize, const std::map<std::string, std::string> &pluginConfig,
             std::string model_lpr_path, bool auto_resize) :
            ie_{ie} {
        // NOTE(review): xmlPath/autoResize duplicate model_lpr_path/auto_resize and are
        // never read — only the latter pair is used. Kept in the signature for caller
        // compatibility; consider consolidating the two pairs upstream.
        (void)xmlPath;
        (void)autoResize;

        InferenceEngine::CNNNetReader LprNetReader;
        LprNetReader.ReadNetwork(model_lpr_path);
        // Weights are expected next to the topology file, with the same base name.
        std::string lprBinFileName = fileNameNoExt(model_lpr_path) + ".bin";
        LprNetReader.ReadWeights(lprBinFileName);

        /** LPR network should have 2 inputs (and second is just a stub) and one output **/
        // ---------------------------Check inputs ------------------------------------------------------
        InferenceEngine::InputsDataMap LprInputInfo(LprNetReader.getNetwork().getInputsInfo());
        if (LprInputInfo.size() != 2) {
            throw std::logic_error("LPR should have 2 inputs");
        }
        // First input: the image. U8 precision; layout depends on who does the resize.
        InferenceEngine::InputInfo::Ptr &LprInputInfoFirst = LprInputInfo.begin()->second;
        LprInputInfoFirst->setPrecision(InferenceEngine::Precision::U8);
        if (auto_resize) {
            // Plugin-side resize: keep NHWC so ROI blobs can be fed without host copies.
            LprInputInfoFirst->getPreProcess().setResizeAlgorithm(InferenceEngine::ResizeAlgorithm::RESIZE_BILINEAR);
            LprInputInfoFirst->setLayout(InferenceEngine::Layout::NHWC);
        } else {
            LprInputInfoFirst->setLayout(InferenceEngine::Layout::NCHW);
        }
        LprInputName = LprInputInfo.begin()->first;
        // Second input: the sequence stub (see setImage); its first dim bounds the
        // number of decoded symbols per plate.
        auto sequenceInput = (++LprInputInfo.begin());
        LprInputSeqName = sequenceInput->first;
        maxSequenceSizePerPlate = sequenceInput->second->getTensorDesc().getDims()[0];
        // -----------------------------------------------------------------------------------------------------

        // ---------------------------Check outputs ------------------------------------------------------
        InferenceEngine::OutputsDataMap LprOutputInfo(LprNetReader.getNetwork().getOutputsInfo());
        if (LprOutputInfo.size() != 1) {
            throw std::logic_error("LPR should have 1 output");
        }
        LprOutputName = LprOutputInfo.begin()->first;

        net = ie_.LoadNetwork(LprNetReader.getNetwork(), deviceName, pluginConfig);
    }

    /// Creates a fresh inference request bound to the compiled LPR network.
    /// Each caller owns its request, so several may run concurrently.
    InferenceEngine::InferRequest Lpr::createInferRequest() {
        auto request = net.CreateInferRequest();
        return request;
    }

    /// Feeds a plate crop into the request's image input and fills the sequence
    /// stub input. Chooses between plugin-side ROI cropping (NHWC/autoResize
    /// path) and a host-side crop-and-copy (NCHW path).
    ///
    /// @param inferRequest  Request whose input blobs are populated.
    /// @param img           Full frame the plate was detected in.
    /// @param plateRect     Plate bounding box within @p img.
    void Lpr::setImage(InferenceEngine::InferRequest &inferRequest, const cv::Mat &img, const cv::Rect plateRect) {
        InferenceEngine::Blob::Ptr roiBlob = inferRequest.GetBlob(LprInputName);
        if (InferenceEngine::Layout::NHWC == roiBlob->getTensorDesc().getLayout()) {  // autoResize is set
            // Zero-copy path: wrap the whole frame and hand the plugin a ROI view;
            // the plugin crops and resizes during inference.
            InferenceEngine::ROI cropRoi{0, static_cast<size_t>(plateRect.x), static_cast<size_t>(plateRect.y),
                                         static_cast<size_t>(plateRect.width),
                                         static_cast<size_t>(plateRect.height)};
            InferenceEngine::Blob::Ptr frameBlob = wrapMat2Blob(img);
            // Renamed from roiBlob: the original declaration shadowed the outer
            // roiBlob above, which invited confusion between the two blobs.
            InferenceEngine::Blob::Ptr croppedBlob = make_shared_blob(frameBlob, cropRoi);
            inferRequest.SetBlob(LprInputName, croppedBlob);
        } else {
            // Host-side path: crop here and copy pixels into the preallocated blob.
            const cv::Mat &vehicleImage = img(plateRect);
            matU8ToBlob<uint8_t>(vehicleImage, roiBlob);
        }
        InferenceEngine::Blob::Ptr seqBlob = inferRequest.GetBlob(LprInputSeqName);
        // second input is sequence, which is some relic from the training
        // it should have the leading 0.0f and rest 1.0f
        float *blob_data = seqBlob->buffer().as<float *>();
        blob_data[0] = 0.0f;
        std::fill(blob_data + 1, blob_data + maxSequenceSizePerPlate, 1.0f);
    }

    /// Decodes the LPR network output into a human-readable plate string.
    ///
    /// The output blob holds up to maxSequenceSizePerPlate class indices encoded
    /// as floats; the sequence is terminated early by a -1 entry. Each index
    /// selects one token from the dictionary below (digits, region tokens,
    /// Latin letters).
    ///
    /// @param inferRequest  Completed request to read the output blob from.
    /// @return Decoded plate text (empty if the first entry is the terminator).
    std::string Lpr::getResults(InferenceEngine::InferRequest &inferRequest) {
        static const std::vector<std::string> items = {
                "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
                "<Anhui>", "<Beijing>", "<Chongqing>", "<Fujian>",
                "<Gansu>", "<Guangdong>", "<Guangxi>", "<Guizhou>",
                "<Hainan>", "<Hebei>", "<Heilongjiang>", "<Henan>",
                "<HongKong>", "<Hubei>", "<Hunan>", "<InnerMongolia>",
                "<Jiangsu>", "<Jiangxi>", "<Jilin>", "<Liaoning>",
                "<Macau>", "<Ningxia>", "<Qinghai>", "<Shaanxi>",
                "<Shandong>", "<Shanghai>", "<Shanxi>", "<Sichuan>",
                "<Tianjin>", "<Tibet>", "<Xinjiang>", "<Yunnan>",
                "<Zhejiang>", "<police>",
                "A", "B", "C", "D", "E", "F", "G", "H", "I", "J",
                "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T",
                "U", "V", "W", "X", "Y", "Z"
        };
        std::string result;
        result.reserve(15u + 6u);  // rough upper bound: longest region token + plate signs
        const auto data = inferRequest.GetBlob(LprOutputName)->buffer().as<float *>();
        // Sequence of class indices, ended early by a -1 terminator.
        for (std::size_t i = 0; i < maxSequenceSizePerPlate; i++) {
            // Truncate once, then both terminate and index off the same integer
            // (the original compared the raw float to -1 but truncated separately
            // when indexing, and mixed a signed loop counter with a size_t bound).
            const int classId = static_cast<int>(data[i]);
            if (classId == -1) {
                break;
            }
            // Guard against malformed network output: the original indexed the
            // dictionary unchecked, which is UB for out-of-range indices.
            if (classId < 0 || static_cast<std::size_t>(classId) >= items.size()) {
                break;
            }
            result += items[static_cast<std::vector<std::string>::size_type>(classId)];
        }
        return result;
    }
}



